/*
  Copyright 2010-2012 Patrik Jonsson, sunrise@familjenjonsson.org

  This file is part of Sunrise.

  Sunrise is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 3 of the License, or
  (at your option) any later version.

  Sunrise is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.
  
*/

#include "analytic_case.h"
#include "grain_model.h"
#include <cmath>

/** Abstract converter from a scalar profile value to the data stored
    in a grid cell. validation_factory::get_data() calls get_data()
    with the profile value evaluated at the cell center and the cell
    tracker itself. */
class datamaker {
public:
  typedef T_grid::T_grid_impl::T_data T_data;
  typedef cell_tracker<T_data> T_cell_tracker;

  /// Virtual so that derived datamakers are destroyed correctly when
  /// deleted through a datamaker pointer (polymorphic base class).
  virtual ~datamaker() {}

  /// Returns the cell data for cell c given the profile value
  /// (density or emissivity scale) passed as the first argument.
  virtual T_data get_data(T_float, const T_cell_tracker&) const=0;
};

/** Datamaker for the stellar emission grid. Each cell gets an emitter
    whose spectrum is emissivity_ scaled by the profile value and the
    cell volume, together with a default-constructed absorber. */
class stellar_datamaker : public datamaker {
public:
  // Spectral shape of the emission; held by reference, so the caller's
  // array must outlive this object.
  const T_lambda& emissivity_;
  // Wavelength array handed unchanged to every emitter that is created.
  const T_lambda* lptr_;

  stellar_datamaker(const T_lambda& em, const T_lambda* lp)
    : emissivity_(em),
      lptr_(lp)
  {}

  virtual T_data get_data(T_float density,
                          const T_cell_tracker& c) const {
    // Keep the multiplication order so results are bit-identical.
    T_lambda cell_spectrum(emissivity_*density*c.volume());
    return T_data(T_data::T_emitter(cell_spectrum, lptr_, c),
                  T_data::T_absorber());
  }
};

/** Datamaker for the dust grid. Each cell gets an empty emitter and an
    absorber holding a single density equal to the profile value, a
    zero vec3d (presumably the bulk velocity -- confirm) and a T_lambda
    of length n_lambda_. */
class dust_datamaker : public datamaker {
public:
  // Number of wavelengths used to size the absorber's T_lambda.
  int n_lambda_;

  dust_datamaker(int nl)
    : n_lambda_(nl)
  {}

  virtual T_data get_data(T_float density,
                          const T_cell_tracker& c) const {
    T_densities rho(1);
    rho = density;
    const vec3d zero_vec(0, 0, 0);
    return T_data(T_data::T_emitter(),
                  T_data::T_absorber(rho, zero_vec, T_lambda(n_lambda_)));
  }
};

/** Grid factory for the Baes et al. validation case. The same class
    builds either the dust grid or the stellar emission grid: it
    evaluates the truncated double-exponential profile
    rho0*exp(-r/hr - |z|/hz) at each cell center (zero unless r < rmax
    and |z| < zmax), and the datamaker turns that value into either
    dust absorber data or stellar emitter data. */
class validation_factory {
public:
  typedef T_grid::T_grid_impl::T_data T_data;
  typedef refinement_accuracy_data<T_data> T_racc;
  typedef cell_tracker<T_data> T_cell_tracker;

  // Central value of the profile (dust density, or the luminosity
  // normalization when building the stellar grid).
  T_float rho0_;
  // Column-density and relative-stdev tolerances; only referenced by
  // the currently disabled refinement/unification criteria below.
  const T_densities n_tol_, stdev_;
  // Radial/vertical scale lengths and truncation radius/height.
  const T_float hr_, hz_, rmax_, zmax_;
  // Maximum refinement level and the value returned by n_threads().
  const int maxlevel, n_threads_;
  // Held by reference: the datamaker must outlive the factory.
  const datamaker& dm_;

  /** Returns the profile value at the center of cell c, or 0 if the
      center lies outside the cylinder r < rmax, |z| < zmax. Does not
      modify the factory. */
  T_float get_rho(const T_cell_tracker& c) const {
    const T_float r = sqrt(c.getcenter()[0]*c.getcenter()[0]+
                           c.getcenter()[1]*c.getcenter()[1]);
    const T_float z = std::abs(c.getcenter()[2]);

    const T_float rho = ((r < rmax_) && (z < zmax_)) ?
      rho0_*exp(-r/hr_ - z/hz_) : 0;
    return rho;
  }

  // The initializer list follows the member declaration order (which is
  // the order C++ actually initializes members in), avoiding -Wreorder.
  validation_factory (const datamaker& dm, T_float rho0,
                      T_float hr, T_float hz, T_float rmax, T_float zmax,
                      const T_densities& n_tol, const T_densities& stdev,
                      int ml, int nt):
    rho0_(rho0), n_tol_(n_tol), stdev_(stdev),
    hr_(hr), hz_(hz), rmax_(rmax), zmax_(zmax),
    maxlevel(ml), n_threads_(nt), dm_(dm) {}

  /** Returns true if cell c should be refined. At present cells are
      refined uniformly until they reach maxlevel; the column-density
      criterion after the early return is unreachable and kept only for
      reference. */
  bool refine_cell_p (const T_cell_tracker& c) {
    const int level = c.code().level();
    return (level<maxlevel);

    // Disabled: refine while the cell's column density, estimated as
    // |size|*rho, exceeds n_tol for every species.
    T_float rho = get_rho(c);
    T_float n2 (dot(c.getsize(), c.getsize())*rho*rho);
    bool refine = all( n2 > (n_tol_*n_tol_));
    return refine;
  }

  /** Returns whether the cells summarized in racc (presumably the
      children of c) may be merged. Unification is currently disabled
      and this always returns false. The criterion after the early
      return is unreachable: it would need to be delegated to the
      datamaker, since the stellar grid has uninitialized absorbers. */
  bool unify_cell_p (const T_cell_tracker& c, const T_racc& racc) {
    return false;
    const T_densities mean (racc.sum.get_absorber().densities()/racc.n);
    const T_densities stdev2 (racc.sumsq.get_absorber().densities()/racc.n- mean*mean);
    return all ((mean==0) || ((stdev2/(mean*mean)<stdev_*stdev_) && (mean<n_tol_)));
  }

  /// Returns the cell data built by the datamaker from get_rho(c).
  T_data get_data (const T_cell_tracker& c) {
    T_data d(dm_.get_data(get_rho(c), c));
    return d;
  }

  /// Thread count passed in at construction (presumably used by the
  /// grid builder -- confirm).
  virtual int n_threads () {return n_threads_;}
};


/** Sets up and runs the Baes et al. validation case.

    The model is an exponential dust disk (hrd, hzd, truncated at
    rdmax/zdmax) and an exponential stellar disk (hrs, hzs, truncated
    at rsmax/zsmax). The stellar disk has a blackbody(3000,1) spectrum
    and total luminosity L_tot. It is viewed by five cameras at polar
    angles 0, 60, 80, 87 and 90 degrees. Lengths are in kpc (see units).

    Reads maxlevel, n_threads, tau_f, tau_tol, npix, grain_data_dir and
    wd01_parameter_set from opt. It then passes the dust-grid factory,
    the stellar emission and the cameras to run_case(). */
void setup (po::variables_map& opt)
{
  T_unit_map units;
  units ["length"] = "kpc";
  units ["mass"] = "Msun";
  units ["luminosity"] = "W";
  units ["wavelength"] = "m";
  units ["L_lambda"] = "W/m";

  const int maxlevel= opt["maxlevel"].as<int>();
  const int n_threads = opt["n_threads"].as<int>();
  const T_float tau_f = opt["tau_f"].as<T_float>();
  const T_float tau_tol = opt["tau_tol"].as<T_float>();
  const T_float cameradist=1e4; //10Mpc

  // Dust disk geometry.
  T_densities stdev_tol(1); stdev_tol = 0.0;
  const vec3d grid_extent(12,12,2);
  const T_float hrd=3.0;
  const T_float hzd=0.15;
  const T_float rdmax=12.0;
  const T_float zdmax=0.9;

  // Stellar disk geometry and total luminosity.
  const T_float hrs=3.0;
  const T_float hzs=0.30;
  const T_float rsmax=12.0;
  const T_float zsmax=1.8;
  const T_float L_tot=1.919e37;

  const int npix=opt["npix"].as<int>();
  const T_float cam_fov=24;

  // create wavelengths
  array_1 lambda(logspace(1e-7,1e-3,81));
  std::vector<T_float> lambdas(lambda.begin(), lambda.end());

  // we need a preferences object to pass parameters to grain model
  Preferences p;
  p.setValue("grain_data_directory",
             word_expand(opt["grain_data_dir"].as<string>())[0]);
  //p.setValue("grain_model", grain_model);
  p.setValue("wd01_parameter_set", opt["wd01_parameter_set"].as<string>());
  p.setValue("template_pah_fraction", 0.5);
  p.setValue("use_dl07_opacities", true);
  p.setValue("n_threads", n_threads);
  p.setValue("use_grain_temp_lookup", true);
  p.setValue("use_cuda", false);

  // load grain model
  T_dust_model::T_scatterer_vector sv;
  sv.push_back(boost::shared_ptr<T_scatterer>
               (new baes_grain_model<polychromatic_scatterer_policy, mcrx_rng_policy>(word_expand(p.getValue("grain_data_directory", string()))[0], p, units)));

  T_dust_model model (sv);
  model.set_wavelength(lambda);

  // We set the 1um wavelength index manually to avoid ambiguities.
  // Assuming logspace includes both endpoints, 81 points over 1e-7..1e-3 m
  // give 20 points per decade, so index 20 is 1e-6 m.
  const int lambda1um=20;

  // find opacity normalization for whatever dust we're using
  const T_float unit_opacity_rho = 1.0/sv[0]->opacity()(lambda1um);

  // Central density from the face-on optical depth: the column through
  // the center is 2*rho0*hzd*(1-exp(-zdmax/hzd)).
  const T_float rho0 = tau_f*unit_opacity_rho/(2*hzd*(1-exp(-zdmax/hzd)));
  cout << "Unit optical depth column density is " << unit_opacity_rho << endl;
  cout << "Face-on optical depth is " << tau_f << endl;
  cout << "Central density is " << rho0 << endl;
  T_densities n_tol(1); n_tol = tau_tol*unit_opacity_rho;
  cout << "Stdev tolerance is " << stdev_tol << endl;
  cout << "Column density tolerance is " << n_tol << endl;

  // create factory
  dust_datamaker ddm(lambda.size());
  validation_factory factory (ddm, rho0, hrd, hzd, rdmax, zdmax,
                              n_tol, stdev_tol, maxlevel, n_threads);

  // build stellar grid
  cout << "\nBuilding stellar emission grid" << endl;
  blackbody bb(3000,1);
  T_lambda rho0s(lambda.size());
  rho0s=bb.emission(lambda);

  T_lambda long_stdev_tol(lambda.size()); long_stdev_tol = stdev_tol(0);
  T_lambda long_n_tol(lambda.size()); long_n_tol=1e300;

  stellar_datamaker sdm(rho0s, &lambda);
  validation_factory stellar_factory (sdm, L_tot, hrs, hzs,
                                      rsmax, zsmax, long_n_tol, long_stdev_tol,
                                      maxlevel, n_threads);

  // create the adaptive grid
  boost::shared_ptr<adaptive_grid<T_grid::T_grid_impl::T_data> > stellargrid
    (new adaptive_grid<T_grid::T_grid_impl::T_data>
     (-grid_extent, grid_extent, stellar_factory));

  // now we need to make a emission collection with the grid_cell emitters
  std::vector<emission<polychromatic_policy, T_rng_policy>* > emitters;
  array_1 L_bol(stellargrid->n_cells()); L_bol=0;
  int cc=0;
  // Step 1: restore the cell pointers, collect the emitters, and record
  // each cell's own (unnormalized) luminosity in L_bol.
  // NOTE(review): stellar_datamaker already multiplies by c.volume();
  // this multiplies by it again -- confirm whether grid_cell_emission
  // expects a per-volume spectrum here.
  for (T_grid::T_grid_impl::iterator c=stellargrid->begin(),
         e=stellargrid->end();
       c!=e; ++c, ++cc) {
    // CRUCIAL: restore the cell pointers for the grid_cell_emission objects
    // this is a flaw in the design of the factory and grid_cell_emission object
    c->data()->get_emitter().set_cell(c);
    emitters.push_back(&c->data()->get_emitter());
    c->data()->get_emitter().get_emission() *= c.volume();
    const T_float L=integrate_quantity(c->data()->get_emitter().get_emission(), lambda, false);
    L_bol(cc)=L;
  }

  const T_float L_sum = sum(L_bol);
  cout << "Total L_bol " << L_sum << endl;

  // Step 2: make sure sampling yields the correct total luminosity (see
  // ir_grid::normalize_for_sampling()). Emitters are sampled in
  // proportion to L_bol. Each emitter's spectrum is therefore rescaled
  // to integrate to the requested L_tot, which also removes any
  // discretization error. The divisor must be the cell's OWN luminosity.
  // Dividing by the already-rescaled weight would leave every emitter
  // integrating to L_sum instead of L_tot. Cells with zero luminosity
  // (outside rsmax/zsmax) are left untouched. Rescaling them would
  // compute 0*(L_tot/0) = NaN, and with zero weight they are never
  // sampled anyway.
  cc=0;
  for (T_grid::T_grid_impl::iterator c=stellargrid->begin(),
         e=stellargrid->end();
       c!=e; ++c, ++cc) {
    if (L_bol(cc) > 0)
      c->data()->get_emitter().get_emission() *= L_tot/L_bol(cc);
  }
  // Normalize the sampling weights to sum to L_tot. Their ratios are
  // unchanged by this rescaling.
  L_bol *= L_tot/L_sum;

  boost::shared_ptr<T_emission> em;
  em.reset(new emission_collection<polychromatic_policy, cumulative_sampling, T_rng_policy> (emitters, L_bol));

  cout << "\nSetting up emergence" << endl;
  std::vector<std::pair<T_float, T_float> > cam_pos;
  const T_float phi=0.3;
  cam_pos.push_back(std:: make_pair (0.0, phi) );
  cam_pos.push_back(std:: make_pair (60./180*constants::pi, phi) );
  cam_pos.push_back(std:: make_pair (80./180*constants::pi, phi) );
  cam_pos.push_back(std:: make_pair (87./180*constants::pi, phi) );
  cam_pos.push_back(std:: make_pair (90./180*constants::pi, phi) );
  boost::shared_ptr<T_emergence> cameras
    (new T_emergence(cam_pos, cameradist, cam_fov, npix));

  run_case(opt,
           factory,
           units,
           lambda,
           -grid_extent,
           grid_extent,
           *em,
           *cameras,
           model);
}

/// Per-case hook taking the transfer object (presumably invoked by the
/// driver before shooting rays -- confirm); this case needs nothing.
void pre_shoot(T_xfer& /* xfer */)
{
}

/// Returns the one-line, human-readable name of this test case.
std::string description()
{
  const char* const text = "Baes et al. RT validation testcase.";
  return std::string(text);
}

/** Registers the command-line options specific to this test case.
    NOTE(review): the help text says V-band, but setup() normalizes
    tau_f at wavelength index lambda1um (1 micron) -- confirm which is
    intended. */
void add_custom_options(po::options_description& desc)
{
  po::options_description_easy_init add_option = desc.add_options();
  add_option("tau_f", po::value<T_float>()->default_value(1),
             "V-band face-on optical depth");
}

/// Writes the values of this case's custom options to standard output.
void print_custom_options(const po::variables_map& opt)
{
  const T_float tau_f = opt["tau_f"].as<T_float>();
  std::cout << "tau_f = " << tau_f << '\n';
}
